{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "BIedd7Pz9sOs"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<stable_baselines3.ppo.ppo.PPO at 0x7f800c1f85e0>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from stable_baselines3 import PPO\n",
    "\n",
    "# Fixed seed so repeated runs start from the same random state.\n",
    "# NOTE: exact reproducibility can still vary across hardware (CPU vs GPU)\n",
    "# and library versions.\n",
    "SEED = 0\n",
    "\n",
    "# Build a PPO agent with an MLP policy on the CartPole-v1 environment.\n",
    "# verbose: (int) verbosity level, 0 = no output, 1 = info, 2 = debug.\n",
    "model = PPO('MlpPolicy', 'CartPole-v1', verbose=0, seed=SEED)\n",
    "\n",
    "# Show the model object as the cell output.\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "4oPTHjxyZSOL"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(38.25, 8.624818838677134)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "\n",
    "# Evaluate the policy before training. The result is a tuple: the first\n",
    "# number is the mean of the per-episode reward sums, the second is their\n",
    "# standard deviation.\n",
    "evaluate_policy(\n",
    "    model,\n",
    "    model.get_env(),\n",
    "    n_eval_episodes=20,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "e4cfSXIB-pTF"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080\"> 100%</span> <span style=\"color: #729c1f; text-decoration-color: #729c1f\">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span style=\"color: #008000; text-decoration-color: #008000\">20,368/20,000 </span> [ <span style=\"color: #808000; text-decoration-color: #808000\">0:00:26</span> &lt; <span style=\"color: #008080; text-decoration-color: #008080\">0:00:00</span> , <span style=\"color: #800000; text-decoration-color: #800000\">756 it/s</span> ]\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[35m 100%\u001b[0m \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20,368/20,000 \u001b[0m [ \u001b[33m0:00:26\u001b[0m < \u001b[36m0:00:00\u001b[0m , \u001b[31m756 it/s\u001b[0m ]\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<stable_baselines3.ppo.ppo.PPO at 0x7f800c1f85e0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train the agent for 20,000 timesteps with a progress bar.\n",
    "# learn() returns the model, which is shown as the cell output.\n",
    "model.learn(\n",
    "    total_timesteps=20_000,\n",
    "    progress_bar=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ygl_gVmV_QP7"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(490.8, 23.24779559442142)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate again after training, using the same settings as the earlier\n",
    "# evaluation so the two (mean, std) results can be compared directly.\n",
    "evaluate_policy(\n",
    "    model,\n",
    "    model.get_env(),\n",
    "    n_eval_episodes=20,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "1_getting_started.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:cuda117]",
   "language": "python",
   "name": "conda-env-cuda117-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
